import h5py
import glob
from argparse import ArgumentParser
import numpy as np
from blockprocessing import BlockProcessing
from tvem.models import TVAE
from tvem.variational import RandomSampledVarStates
import torch as to
from pathlib import Path
import matplotlib.pyplot as plt
plt.switch_backend('agg')

# Given the original image, the name of the noisy dataset, and a list of
# output files of runs on that dataset, this program prints, for each run, the
# PSNR value obtained using TVAE model parameters at the epoch of best free
# energy and variational states at the last epoch (the only ones saved by the
# TVEM framework).


def eval_psnr(fig1, fig2):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    The peak amplitude is fixed at 255 rather than taken from the data.

    Args:
        fig1: array-like reference image.
        fig2: array-like image to compare, broadcastable against `fig1`.

    Returns:
        ``20 * log10(255 / rms)`` where ``rms`` is the root-mean-square of
        ``fig1 - fig2`` over all elements, or ``np.inf`` if the two inputs
        are identical.
    """
    # Promote BOTH inputs to float64 before subtracting: integer inputs
    # (e.g. two int8 or two uint8 arrays) would otherwise wrap around in the
    # difference and silently yield a wrong PSNR.
    reference = np.asarray(fig1, dtype=np.float64)
    estimate = np.asarray(fig2, dtype=np.float64)

    pamp = 255  # was: np.amax(fig1)
    rms = np.sqrt(np.mean((reference - estimate) ** 2))
    if rms == 0.0:
        return np.inf
    return 20.0 * np.log10(pamp / rms)


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Positional arguments: ``original``, ``dataset_name`` and one or more
    ``runs``. Optional arguments: ``--draw-in-dir`` and ``--noisy-image``
    (both default to ``None``).

    Returns:
        The ``argparse.Namespace`` holding the parsed values.
    """
    parser = ArgumentParser()

    # Positional arguments, in command-line order.
    parser.add_argument(
        "original",
        help=".h5 file with original image",
    )
    parser.add_argument(
        "dataset_name",
        help=".h5 file with noisy patches generated from original image",
    )
    parser.add_argument(
        "runs",
        nargs="+",
        help=".h5 files outputted by the TVEM framework during training on the dataset",
    )

    # Optional drawing controls.
    parser.add_argument(
        "--draw-in-dir",
        required=False,
        help="if present, draw reconstructed data and save .svg files in specified directory",
    )
    parser.add_argument(
        "--noisy-image",
        required=False,
        help="if present, name of HDF5 file with noisy image. also draw noisy image contained in specified .h5 file (not patches) -- requires --draw-in-dir",
    )

    return parser.parse_args()


def depatchify(reco_patches, image):
    """Reassemble reconstructed patches into an image and score it.

    Args:
        reco_patches: 2-D array whose second dimension holds the flattened
            pixels of one patch; patches are treated as square, with side
            ``int(sqrt(reco_patches.shape[1]))``.
        image: reference image; it sets the geometry of the reassembled
            image and is the PSNR reference.

    Returns:
        Tuple ``(reassembled_image, psnr)``.
    """
    side = int(np.sqrt(reco_patches.shape[1]))

    block_proc = BlockProcessing(
        image.copy(),  # pass a copy: BlockProcessing modifies its input
        mask=np.zeros_like(image),  # 0 -> to reconstruct
        patchheight=side,
        patchwidth=side,
        pp_params={"pp_type": None, "sf_type": "gauss_"},
    )

    # Cut the image into blocks, overwrite the blocks with the reconstructed
    # patches (transposed to match Y's layout), then merge them back.
    block_proc.im2bl()
    block_proc.Y[:] = reco_patches.T
    block_proc.bl2im()

    reassembled = block_proc.I
    return reassembled, eval_psnr(image, reassembled)


def _save_gray_figure(image, title, out_path):
    """Draw `image` in grayscale (0-255 range, no ticks) and save it to `out_path`."""
    fig = plt.figure()
    plt.imshow(image, cmap="gray", vmin=0, vmax=255)
    plt.title(title)
    ax = plt.gca()
    ax.set_xticks([])
    ax.set_yticks([])
    plt.savefig(str(out_path))
    # Close explicitly: otherwise one figure per run stays alive in pyplot.
    plt.close(fig)


def get_psnr(run, dataset, clean_img, draw_in_dir=None, noisy_image=None):
    """Evaluate one training run: free energy, PSNR and best epoch index.

    A TVAE is rebuilt from the parameters stored in `run` at the index of
    maximal ``train_F``. The stored variational states are scored against
    `dataset`, every data point is reconstructed, and the reconstruction is
    depatchified and compared to `clean_img`.

    Args:
        run: open HDF5 file of a run. Must provide ``train_F``, ``theta``
            (``W_0``, ``b_0``, ``W_1``, ``b_1``, ``pies``, ``sigma2``),
            ``train_states``, ``best_train_states``,
            ``exp_config/batch_size`` and a ``filename`` attribute.
        dataset: numpy array with one data point per row (presumably the
            noisy patches the run was trained on -- confirm with caller).
        clean_img: reference image passed to `depatchify` for the PSNR.
        draw_in_dir: if not None, directory (created if missing) where the
            reconstruction is saved as ``<run stem>.svg``.
        noisy_image: if not None (and `draw_in_dir` is set), path of an
            HDF5 file whose ``data`` entry is the noisy image; it is saved
            as ``<run stem>_noisy.svg``.

    Returns:
        Tuple ``(F, psnr, best_idx)``: free energy divided by the number of
        data points, PSNR of the reconstruction, and the argmax index of
        ``train_F``.

    Raises:
        AssertionError: if the noisy image's shape differs from the
            reconstructed image's shape.
    """
    Fs = run["train_F"][...]
    best_idx = np.argmax(Fs)
    theta = run["theta"]
    # NOTE(review): pies and sigma2 are read one entry earlier than W/b.
    # Presumably this mirrors when TVEM updates/logs them -- confirm. For
    # best_idx == 0 the index is -1, i.e. the LAST stored entry.
    th = {"W_0": to.from_numpy(theta["W_0"][best_idx]),
          "b_0": to.from_numpy(theta["b_0"][best_idx]),
          "W_1": to.from_numpy(theta["W_1"][best_idx]),
          "b_1": to.from_numpy(theta["b_1"][best_idx]),
          "pies": to.from_numpy(theta["pies"][best_idx-1]),
          "sigma2": theta["sigma2"][best_idx-1]}

    tvae = TVAE(W_init=(th["W_0"], th["W_1"]), b_init=(th["b_0"], th["b_1"]),
                pi_init=th["pies"], sigma2_init=th["sigma2"], precision=to.float32)

    H, N = tvae.shape[1], dataset.shape[0]
    S = run["train_states"].shape[1]
    states_conf = dict(precision=tvae.precision, H=H, N=N, S=S)
    states = RandomSampledVarStates(S_new=0, conf=states_conf)
    # NOTE(review): S comes from "train_states" but K is filled from
    # "best_train_states" -- confirm both datasets exist with equal shape.
    states.K[:] = to.from_numpy(run["best_train_states"][...])
    data_as_torch = to.from_numpy(dataset).to(tvae.precision)

    # Recompute the log-pseudo-joints of the stored states under the
    # selected parameters, in small chunks.
    batch_size = 32
    for chunk_start in range(0, N, batch_size):
        idxs = to.arange(chunk_start, min(chunk_start + batch_size, N))
        states.lpj[idxs] = tvae.log_pseudo_joint(data_as_torch[idxs], states.K[idxs])

    F = tvae.free_energy(to.arange(N), data_as_torch, states) / N

    # Reconstruct all data points using the batch size of the experiment.
    batch_size = run["exp_config"]["batch_size"][...].item()
    reco_patches = to.empty((N, tvae.shape[0]), dtype=tvae.precision)
    for chunk_start in range(0, N, batch_size):
        idxs = to.arange(chunk_start, min(chunk_start + batch_size, N))
        reco_patches[idxs] = tvae.data_estimator(idxs, states)

    img, psnr = depatchify(reco_patches.numpy(), clean_img)

    if draw_in_dir is not None:
        d = Path(draw_in_dir)
        d.mkdir(parents=True, exist_ok=True)
        # String concatenation (not with_suffix) so stems with dots survive.
        out_stem = str(d / Path(run.filename).stem)

        _save_gray_figure(
            img, f"Reconstructed image (PSNR = {psnr:.2f})", out_stem + ".svg"
        )

        # Use the parameter, not the script-level `args`, so the function
        # also works when called from other code.
        if noisy_image is not None:
            with h5py.File(noisy_image, "r") as noisy_img:
                assert noisy_img["data"].shape == img.shape
                noisy_data = noisy_img["data"][...]
            # Separate file name: saving to "<stem>.svg" again would
            # overwrite the reconstruction written just above.
            _save_gray_figure(
                noisy_data, "Original noisy image", out_stem + "_noisy.svg"
            )

    return F, psnr, best_idx


if __name__ == "__main__":
    args = parse_args()

    # Load the noisy patches and the clean reference image into memory.
    with h5py.File(args.dataset_name, "r") as patches_file:
        noisy_patches = patches_file["data"][...]
    with h5py.File(args.original, "r") as original_file:
        clean_img = original_file["data"][...]

    # One (file name, free energy, best index, PSNR) record per run.
    results = []
    for fname in args.runs:
        with h5py.File(fname, "r") as run:
            if "train_F" not in run:
                raise KeyError(f"ERROR: file {fname} does not contain 'train_F'")
            best_F, psnr, idx = get_psnr(
                run, noisy_patches, clean_img, args.draw_in_dir, args.noisy_image
            )
        results.append((fname, best_F, idx, psnr))

    # Print the table only after every run has been evaluated.
    for f, F, idx, psnr in results:
        print(f"{psnr:5.2f}\t{F:10.2f}\t{idx:6}\t{f}")
    print()
